{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/EvgenyKashin/gan-vis/blob/master/Playing_with_gans.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from pathlib import Path\n",
    "import shutil\n",
    "import colorsys\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as data\n",
    "\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "plt.ioff()\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthetic data distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def two_gaussians_func(size):\n",
    "    a = torch.randn((size // 2, 2)) * 0.1 + 0.4\n",
    "    b = torch.randn((size // 2, 2)) * 0.2 - 0.5\n",
    "    x = torch.cat((a, b))\n",
    "    return x\n",
    "\n",
    "def ring_func(size):\n",
    "    r = torch.rand((size, 1))\n",
    "    x = 2 * torch.cos(r * math.pi * 2)\n",
    "    y = torch.sin(r * math.pi * 2) + 0.5\n",
    "    xy = torch.cat((x, y), axis=1)\n",
    "    xy += torch.randn((size, 2)) * 0.1\n",
    "    return xy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview 200 samples of the two-Gaussian mixture as a green scatter plot.\n",
    "two_gaussians = two_gaussians_func(200)\n",
    "plt.scatter(two_gaussians[:, 0], two_gaussians[:, 1], c='g', marker='.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview 200 samples of the ring distribution as a green scatter plot.\n",
    "ring = ring_func(200)\n",
    "plt.scatter(ring[:, 0], ring[:, 1], c='g', marker='.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDistribution(data.Dataset):\n",
    "    def __init__(self, distribution_func, size):\n",
    "        super().__init__()\n",
    "        self.data = distribution_func(size)\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        return self.data[index]\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch-size multiplier; it is also used later to scale the learning rates.\n",
    "bs_multiplier = 2 # 1 / 16, 1, 16, etc\n",
    "batch_size = int(16 * bs_multiplier)\n",
    "dataset_size = 256\n",
    "\n",
    "# Sampling function for the real data: two_gaussians_func or ring_func\n",
    "synthetic_dataset = CustomDistribution(ring_func, dataset_size)\n",
    "synthetic_dl = data.DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True)\n",
    "# All samples gathered by one pass over the (shuffled) loader; the plotting\n",
    "# helpers read this tensor.\n",
    "synthetic_dataset_all = torch.cat(list(synthetic_dl), axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most of the functions below rely on global variables (`G`, `D`, `val_noise`, the metric lists, ...) — don't try this at home."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folders: frames saved by the training loop and frames of the\n",
    "# noise -> G(noise) transition, respectively.\n",
    "animation_ims = Path('animation_ims')\n",
    "noise2G_imgs = Path('noise2G_imgs')\n",
    "\n",
    "# Subplot grid used by vis() and vis_G_intermediate().\n",
    "n_rows = 2\n",
    "n_cols = 4\n",
    "\n",
    "def clear_folder(folder):\n",
    "    shutil.rmtree(folder, ignore_errors=True)\n",
    "    folder.mkdir(exist_ok=True)\n",
    "\n",
    "def get_grad_norm(model):\n",
    "    total_norm = 0\n",
    "    for p in model.parameters():\n",
    "        if p.grad is not None:\n",
    "            param_norm = p.grad.data.norm(2)\n",
    "            total_norm += param_norm.item() ** 2\n",
    "    total_norm = total_norm ** (1. / 2)\n",
    "    return total_norm\n",
    "\n",
    "def vis_real():\n",
    "    \"\"\"Scatter-plot all real samples (global `synthetic_dataset_all`) in black.\"\"\"\n",
    "    real_x, real_y = synthetic_dataset_all[:, 0], synthetic_dataset_all[:, 1]\n",
    "    plt.scatter(real_x, real_y, marker='.', c='black')\n",
    "\n",
    "def vis_G():\n",
    "    \"\"\"Scatter-plot, in red, what global `G` generates from the fixed `val_noise`.\"\"\"\n",
    "    generated = G(val_noise).detach()\n",
    "    gen_x, gen_y = generated[:, 0], generated[:, 1]\n",
    "    plt.scatter(gen_x, gen_y, marker='.', c='r')\n",
    "\n",
    "def vis_real_fake():\n",
    "    \"\"\"Overlay the real samples and the generated samples on the current axes.\"\"\"\n",
    "    vis_real()\n",
    "    vis_G()\n",
    "    \n",
    "def vis_grad_norms():\n",
    "    \"\"\"Plot the recorded gradient norms of D and G (global history lists).\"\"\"\n",
    "    for label, history in (('D', d_grad_norms), ('G', g_grad_norms)):\n",
    "        plt.plot(history, label=label)\n",
    "    plt.legend()\n",
    "    plt.title('Grad norms')\n",
    "    \n",
    "def vis_losses():\n",
    "    \"\"\"Plot the recorded D (fake / real) and G losses (global history lists).\"\"\"\n",
    "    curves = (\n",
    "        ('d_fake', d_losses_fakes),\n",
    "        ('d_real', d_losses_reals),\n",
    "        ('g', g_losses),\n",
    "    )\n",
    "    for label, history in curves:\n",
    "        plt.plot(history, label=label)\n",
    "    plt.legend()\n",
    "    plt.title('Losses')\n",
    "    \n",
    "def vis_preds():\n",
    "    \"\"\"Plot the recorded mean D predictions on real and on fake batches.\"\"\"\n",
    "    for label, history in (('real', real_preds), ('fake', fake_preds)):\n",
    "        plt.plot(history, label=label)\n",
    "    plt.legend()\n",
    "    plt.title('D preds')\n",
    "    \n",
    "def get_colors_from_matrix(x):\n",
    "    x_max, y_max = x.max(axis=0)[0]\n",
    "    x_min, y_min = x.min(axis=0)[0]\n",
    "    cs = []\n",
    "    \n",
    "    for n in x:\n",
    "        h = (n[0] - x_min) / (x_max - x_min) \n",
    "        v = (n[1] - y_min) / (y_max - y_min)\n",
    "        cs.append(colorsys.hsv_to_rgb(h * 0.9 + 0.1, 1,\n",
    "                                      v * 0.9 + 0.1))\n",
    "    return cs\n",
    "    \n",
    "# Random projection matrices, cached so that successive plots use the same\n",
    "# 2-D view of the hidden activations.\n",
    "random_matrix_l1, random_matrix_l2 = None, None\n",
    "\n",
    "def _project_to_2d(activations, projection):\n",
    "    \"\"\"Project `activations` (n, width) to 2-D with a cached random matrix.\n",
    "\n",
    "    A fresh matrix is drawn when nothing is cached yet or when the cached one\n",
    "    no longer matches the activation width (e.g. G was re-created with another\n",
    "    `hidden`), which would otherwise make torch.mm fail.\n",
    "\n",
    "    Returns:\n",
    "        (projected activations, projection matrix actually used)\n",
    "    \"\"\"\n",
    "    width = activations.shape[1]\n",
    "    if projection is None or projection.shape[0] != width:\n",
    "        projection = torch.randn((width, 2))\n",
    "    return torch.mm(activations, projection), projection\n",
    "\n",
    "def vis_G_intermediate():\n",
    "    \"\"\"Fill subplots 5-8 with the noise and G's layer-wise activations.\n",
    "\n",
    "    Points are colored by their position in the input noise so they can be\n",
    "    followed through the layers; hidden activations are shown through a\n",
    "    random 2-D projection.\n",
    "    \"\"\"\n",
    "    global random_matrix_l1, random_matrix_l2\n",
    "\n",
    "    l1, l2, l3 = (t.detach() for t in G.forward_intermediate(val_noise))\n",
    "\n",
    "    cs = get_colors_from_matrix(val_noise)\n",
    "\n",
    "    plt.subplot(n_rows, n_cols, 5)\n",
    "    plt.scatter(val_noise[:, 0], val_noise[:, 1], marker='.',  c=cs)\n",
    "    plt.title('Input noise')\n",
    "\n",
    "    plt.subplot(n_rows, n_cols, 6)\n",
    "    l1, random_matrix_l1 = _project_to_2d(l1, random_matrix_l1)\n",
    "    plt.scatter(l1[:, 0], l1[:, 1], marker='.',  c=cs)\n",
    "    plt.title('Layer 1 activations')\n",
    "\n",
    "    plt.subplot(n_rows, n_cols, 7)\n",
    "    l2, random_matrix_l2 = _project_to_2d(l2, random_matrix_l2)\n",
    "    plt.scatter(l2[:, 0], l2[:, 1], marker='.',  c=cs)\n",
    "    plt.title('Layer 2 activations')\n",
    "\n",
    "    plt.subplot(n_rows, n_cols, 8)\n",
    "    plt.scatter(l3[:, 0], l3[:, 1], marker='.',  c=cs)\n",
    "    plt.title('Layer 3 activations')\n",
    "\n",
    "def vis_D_decision_boundary():\n",
    "    \"\"\"Draw D's output as a filled contour plot with a colorbar.\n",
    "\n",
    "    D is evaluated on a 100 x 100 grid spanning the bounding box of the real\n",
    "    samples together with the points generated from `val_noise`.\n",
    "    \"\"\"\n",
    "    g_fake = G(val_noise).detach()\n",
    "    fake_and_real = torch.cat((synthetic_dataset_all, g_fake))\n",
    "\n",
    "    x_min, y_min = fake_and_real.min(dim=0)[0].tolist()\n",
    "    x_max, y_max = fake_and_real.max(dim=0)[0].tolist()\n",
    "\n",
    "    # `steps` is passed explicitly: omitting it relied on an old default of\n",
    "    # 100 that recent PyTorch versions no longer provide.\n",
    "    grid_steps = 100\n",
    "    xx, yy = torch.meshgrid(torch.linspace(x_min, x_max, grid_steps),\n",
    "                            torch.linspace(y_min, y_max, grid_steps))\n",
    "\n",
    "    # One (x, y) row per grid node.\n",
    "    mesh_points = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=1)\n",
    "    with torch.no_grad():  # plotting only, no graph needed\n",
    "        mesh_pred = D(mesh_points).view(*xx.shape)\n",
    "\n",
    "    plt.contourf(xx, yy, mesh_pred, cmap='coolwarm')\n",
    "    plt.colorbar()\n",
    "\n",
    "def vis_G_grad():\n",
    "    \"\"\"Draw one arrow per generated point along d(mean D(fake)) / d(fake).\n",
    "\n",
    "    The gradient is scaled by a fixed factor so the arrows are visible.\n",
    "    \"\"\"\n",
    "    fake = G(val_noise)\n",
    "    mean_score = D(fake).mean()\n",
    "    grads = torch.autograd.grad(mean_score, fake)[0]\n",
    "\n",
    "    arrow_mult = 5e1\n",
    "    origins = fake.detach()\n",
    "    for (x0, y0), (gx, gy) in zip(origins, grads):\n",
    "        plt.arrow(x0, y0, gx * arrow_mult, gy * arrow_mult,\n",
    "                  head_width=4e-2, color='g', alpha=0.4, width=5e-4)\n",
    "\n",
    "def vis(ep, with_grad=False):\n",
    "    \"\"\"Draw the full 2 x 4 dashboard for epoch `ep` on a new figure.\n",
    "\n",
    "    Panel 1 shows D's decision surface with real and generated points (plus\n",
    "    gradient arrows when `with_grad` is true), panels 2-4 the training\n",
    "    curves, and panels 5-8 G's intermediate activations. Switches G and D to\n",
    "    eval mode.\n",
    "    \"\"\"\n",
    "    for model in (G, D):\n",
    "        model.eval()\n",
    "\n",
    "    f = plt.figure(figsize=(16, 8))\n",
    "    plt.suptitle(f'Epoch: {ep}')\n",
    "\n",
    "    plt.subplot(n_rows, n_cols, 1)\n",
    "    vis_D_decision_boundary()\n",
    "    vis_real_fake()\n",
    "    if with_grad:\n",
    "        vis_G_grad()\n",
    "\n",
    "    curve_panels = (vis_grad_norms, vis_losses, vis_preds)\n",
    "    for panel_idx, draw_panel in enumerate(curve_panels, start=2):\n",
    "        plt.subplot(n_rows, n_cols, panel_idx)\n",
    "        draw_panel()\n",
    "\n",
    "    vis_G_intermediate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: plot the real samples before any model exists.\n",
    "vis_real()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    def __init__(self, hidden=8, act=torch.tanh):\n",
    "        super().__init__()\n",
    "        self.act = act\n",
    "        self.l1 = nn.Linear(2, hidden, bias=False)\n",
    "        self.bn1 = nn.BatchNorm1d(hidden)\n",
    "        self.l2 = nn.Linear(hidden, hidden, bias=False)\n",
    "        self.bn2 = nn.BatchNorm1d(hidden)\n",
    "        self.l3 = nn.Linear(hidden, 2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.act(self.l1(x))\n",
    "        x = self.bn1(x)\n",
    "        \n",
    "        x = self.act(self.l2(x))\n",
    "        x = self.bn2(x)\n",
    "        \n",
    "        x = self.l3(x)\n",
    "        return x\n",
    "    \n",
    "    def forward_intermediate(self, x):\n",
    "        l1 = self.act(self.l1(x))\n",
    "        l1 = self.bn1(l1)\n",
    "        \n",
    "        l2 = self.act(self.l2(l1))\n",
    "        l2 = self.bn2(l2)\n",
    "        \n",
    "        l3 = self.l3(l2)\n",
    "        return l1, l2, l3\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, hidden=8, act=torch.tanh):\n",
    "        super().__init__()\n",
    "        self.act = act\n",
    "        self.l1 = nn.Linear(2, hidden, bias=False)\n",
    "        self.bn1 = nn.BatchNorm1d(hidden)\n",
    "        self.l2 = nn.Linear(hidden, hidden, bias=False)\n",
    "        self.l3 = nn.Linear(hidden, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.act(self.l1(x))\n",
    "        x = self.bn1(x)\n",
    "        \n",
    "        x = self.act(self.l2(x))        \n",
    "        x = torch.sigmoid(self.l3(x))\n",
    "        return x\n",
    "\n",
    "def get_noise(bs, distribution='norm'):\n",
    "    if distribution == 'uniform':\n",
    "        # uniform [-1, 1]\n",
    "        noise = torch.rand(bs, 2) * 2 - 1\n",
    "    elif distribution == 'norm':\n",
    "        # normal N(0, 1)\n",
    "        noise = torch.randn(bs, 2)\n",
    "    else:\n",
    "        raise ValueError('Wrong distribution parameter')\n",
    "    \n",
    "    return noise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will probably need a few runs from scratch to get a good result; each run can turn out very different."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The learning rates are the most important hyperparameters; they are scaled\n",
    "# together with the batch size through bs_multiplier.\n",
    "g_lr = 1e-2 * bs_multiplier\n",
    "d_lr = 2e-2 * bs_multiplier\n",
    "epochs = 30\n",
    "\n",
    "G = Generator(hidden=16, act=torch.relu) # torch.tanh or torch.relu\n",
    "D = Discriminator(hidden=16, act=torch.relu) # torch.tanh or torch.relu\n",
    "\n",
    "criterion = nn.BCELoss() # nn.BCELoss() or nn.MSELoss() \n",
    "\n",
    "g_optim = optim.SGD(G.parameters(), g_lr, weight_decay=1e-2)\n",
    "d_optim = optim.SGD(D.parameters(), d_lr, weight_decay=1e-2)\n",
    "\n",
    "# Alternative optimizers:\n",
    "# g_optim = optim.Adam(G.parameters(), g_lr, betas=(0.5, 0.999))\n",
    "# d_optim = optim.Adam(D.parameters(), d_lr, betas=(0.5, 0.999))\n",
    "\n",
    "# fixed noise for visualization: a quarter of\n",
    "# (batch_size * number of batches) points\n",
    "val_noise = get_noise(batch_size * len(synthetic_dl) // 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weights initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional custom weight initialisation (disabled; uncomment to use).\n",
    "# NOTE(review): the first two Linear layers of G and D are created with\n",
    "# bias=False, so m.bias is None for them and nn.init.normal_ would\n",
    "# presumably fail on it - guard with `if m.bias is not None:` before enabling.\n",
    "# def init_weights(m):\n",
    "#     if type(m) == nn.Linear:\n",
    "#         nn.init.xavier_uniform_(m.weight)\n",
    "#         nn.init.normal_(m.bias, std=0.1)\n",
    "\n",
    "# G.apply(init_weights)\n",
    "# D.apply(init_weights)\n",
    "\n",
    "# list(G.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Look at the untrained models. The plotting helpers run forward passes;\n",
    "# in train mode those would update the BatchNorm running statistics with\n",
    "# visualization inputs (e.g. the mesh grid), so switch to eval mode for the\n",
    "# plots and restore the previous modes afterwards.\n",
    "g_was_training, d_was_training = G.training, D.training\n",
    "G.eval()\n",
    "D.eval()\n",
    "\n",
    "vis_D_decision_boundary()\n",
    "vis_real_fake()\n",
    "vis_G_grad()\n",
    "\n",
    "G.train(g_was_training)\n",
    "D.train(d_was_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Per-step training history, read by the vis_* helpers.\n",
    "d_grad_norms = []\n",
    "g_grad_norms = []\n",
    "\n",
    "d_losses_fakes = []\n",
    "d_losses_reals = []\n",
    "g_losses = []\n",
    "\n",
    "fake_preds = []\n",
    "real_preds = []\n",
    "\n",
    "clear_folder(animation_ims)\n",
    "save_for_animation = False # save images for making gifs\n",
    "\n",
    "for ep in range(epochs):\n",
    "    for i, batch_data in enumerate(synthetic_dl):\n",
    "        # The last batch is smaller than `batch_size` when the dataset size\n",
    "        # is not a multiple of it, so size noise and targets by the batch\n",
    "        # actually received instead of the global constant.\n",
    "        cur_bs = len(batch_data)\n",
    "\n",
    "        # Train D\n",
    "        D.train()\n",
    "        # vis() and the animation branch leave G in eval mode; set the mode\n",
    "        # explicitly so training does not depend on what ran before.\n",
    "        G.train()\n",
    "        d_optim.zero_grad()\n",
    "\n",
    "        # on real\n",
    "        real_pred = D(batch_data)\n",
    "        d_loss_real = criterion(real_pred, torch.ones_like(real_pred))\n",
    "        # .item() stores plain floats rather than tensors in the history\n",
    "        d_losses_reals.append(d_loss_real.item())\n",
    "        real_preds.append(real_pred.mean().item())\n",
    "\n",
    "        # on fake (detached: this step must not touch G's gradients)\n",
    "        fake = G(get_noise(cur_bs))\n",
    "        fake_pred = D(fake.detach())\n",
    "        d_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))\n",
    "        d_losses_fakes.append(d_loss_fake.item())\n",
    "        fake_preds.append(fake_pred.mean().item())\n",
    "\n",
    "        d_loss = d_loss_real + d_loss_fake\n",
    "        d_loss.backward()\n",
    "        d_grad_norms.append(get_grad_norm(D))\n",
    "        d_optim.step()\n",
    "\n",
    "        # Train G (D in eval mode while it only scores G's samples)\n",
    "        D.eval()\n",
    "        G.train()\n",
    "        g_optim.zero_grad()\n",
    "\n",
    "        fake = G(get_noise(cur_bs))\n",
    "        fake_pred = D(fake)\n",
    "        g_loss = criterion(fake_pred, torch.ones_like(fake_pred))\n",
    "        g_losses.append(g_loss.item())\n",
    "\n",
    "        g_loss.backward()\n",
    "        g_grad_norms.append(get_grad_norm(G))\n",
    "        g_optim.step()\n",
    "\n",
    "        # Save for animation (slow), disable this for experiments\n",
    "        if save_for_animation and i % 3 == 0:\n",
    "            G.eval()\n",
    "            D.eval()\n",
    "\n",
    "            f = plt.figure()\n",
    "            vis_D_decision_boundary()\n",
    "            vis_real_fake()\n",
    "            vis_G_grad()\n",
    "\n",
    "            current_step = ep * len(synthetic_dl) + i\n",
    "            total_steps = epochs * len(synthetic_dl)\n",
    "            plt.title(f'{int(current_step / total_steps * 100)}%')\n",
    "            f.savefig(f'{animation_ims}/{current_step:05}.jpg')\n",
    "            plt.close(f)\n",
    "\n",
    "    # Visualization\n",
    "    clear_output(True)\n",
    "    vis(ep, with_grad=True)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Creating GIFs from the saved images with ImageMagick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !apt install imagemagick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !convert -delay 10 -layers optimize animation_ims/*.jpg anim.gif"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Noise after passing the generator step by step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from an empty output folder for the transition frames.\n",
    "clear_folder(noise2G_imgs)\n",
    "\n",
    "noise = get_noise(512)\n",
    "# One color per noise point, reused for every frame by the plotting loop.\n",
    "cs = get_colors_from_matrix(noise)\n",
    "fake = G(noise)\n",
    "# 50 evenly spaced linear interpolation steps from the noise points to\n",
    "# the points G produces from them.\n",
    "transition_space = np.linspace(noise, fake.detach(), num=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Render the interpolation frame by frame; each frame replaces the previous\n",
    "# cell output and is optionally written to disk for the GIF.\n",
    "for i, frame in enumerate(transition_space):\n",
    "    clear_output(True)\n",
    "    f = plt.figure()\n",
    "    plt.scatter(frame[:, 0], frame[:, 1], marker='.', c=cs)\n",
    "    plt.title(f'{int(i / len(transition_space) * 100)}%')\n",
    "\n",
    "    if save_for_animation:\n",
    "        f.savefig(f'{noise2G_imgs}/{i:03}.jpg')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !convert -delay 8 -layers optimize noise2G_imgs/*.jpg anim2.gif"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
